// Copyright (c) 2025 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.

// Package main provides a tool for training a zstd dictionary for compressing machine configuration YAML files and Kubernetes manifests.
package main

import (
	"bufio"
	"bytes"
	_ "embed"
	"fmt"
	"log"
	"math/rand/v2"
	"os"
	"path/filepath"
	"slices"
	"text/template"

	"github.com/klauspost/compress/dict"
	"github.com/klauspost/compress/zstd"
	helmclient "github.com/mittwald/go-helm-client"
	"go.uber.org/zap"
	"go.uber.org/zap/zapio"
	"go.yaml.in/yaml/v4"
	"helm.sh/helm/v3/pkg/chartutil"
	"helm.sh/helm/v3/pkg/repo"
)

// Settings for rendering the Argo CD Helm chart whose manifests are used as training samples.
const (
	argoCDChartRepoName    = "argo"
	argoCDChartRepo        = "https://argoproj.github.io/argo-helm"
	argoCDChartName        = "argo-cd"
	argoCDChartVersion     = "7.5.2"
	argoCDChartReleaseName = "my-argo-cd"
	argoCDChartNamespace   = "default"
	argoCDChartKubeVersion = "1.31.0"
)

// Settings of the trained zstd dictionary.
const (
	zstdDictID      = 1         // dictionary ID; also part of the output file name
	zstdMaxDictSize = 64 * 1024 // 64 KiB
)

var (
	// machineConfigTemplate is a Go template for generating machine configuration YAML.
	//
	//go:embed data/machineconfig.tmpl.yaml
	machineConfigTemplate string

	// argoCDValues is a Helm values file for Argo CD.
	//
	//go:embed data/argocd-values.yaml
	argoCDValues string
)

// main trains the dictionary and exits with a fatal log message if training fails.
func main() {
	err := train()
	if err == nil {
		return
	}

	log.Fatalf("failed to train dictionary: %v", err)
}

// getDestPath returns the output path of the dictionary file,
// <data dir>/config.<zstdDictID>.zdict.
//
// The data directory client/pkg/compression/data is located relative to the current working
// directory. Running from hack/zstd-dict is tried first, then running from the project root.
// If neither candidate exists, the os.Stat error for the last candidate is returned.
func getDestPath() (string, error) {
	wd, err := os.Getwd()
	if err != nil {
		return "", err
	}

	var statErr error

	for _, rel := range [...]string{
		"../../client/pkg/compression/data", // working directory is hack/zstd-dict
		"client/pkg/compression/data",       // working directory is the project root
	} {
		candidate := filepath.Clean(filepath.Join(wd, rel))

		if _, statErr = os.Stat(candidate); statErr == nil {
			return filepath.Join(candidate, fmt.Sprintf("config.%d.zdict", zstdDictID)), nil
		}
	}

	return "", statErr
}

// train builds the zstd dictionary and writes it to the path returned by getDestPath.
//
// The training samples are:
//   - 64 randomized machine configs with comments;
//   - 64 randomized machine configs with comments stripped;
//   - the rendered Argo CD chart manifests, split into chunks of at most 112KB.
//
// The samples are also saved via saveInputs so the result can be compared with the zstd CLI.
// The dictionary builder's debug output is forwarded to a development zap logger.
func train() error {
	logger, err := zap.NewDevelopment()
	if err != nil {
		return fmt.Errorf("creating logger: %w", err)
	}

	defer logger.Sync() //nolint:errcheck

	// logWriter forwards the dictionary builder's debug output to the logger.
	logWriter := &zapio.Writer{
		Log: logger,
	}

	// Deferred after logger.Sync so it runs first and flushes any buffered partial line into the logger.
	defer logWriter.Close() //nolint:errcheck

	// Resolve the destination before the expensive work (chart rendering, training), so running from
	// an unsupported working directory fails immediately.
	destPath, err := getDestPath()
	if err != nil {
		return fmt.Errorf("determining destination path: %w", err)
	}

	machineConfigInputs, err := generateMachineConfigInputs(64, true)
	if err != nil {
		return fmt.Errorf("generating machine config inputs: %w", err)
	}

	machineConfigInputsNoComments, err := generateMachineConfigInputs(64, false)
	if err != nil {
		return fmt.Errorf("generating machine config inputs without comments: %w", err)
	}

	// zstd doesn't like input files larger than 128KB, so we split the Argo CD manifests into chunks.
	argoCDInputs, err := generateArgoCDManifestInputs(112 << 10) // 112KB
	if err != nil {
		return fmt.Errorf("generating Argo CD manifest inputs: %w", err)
	}

	inputs := slices.Concat(machineConfigInputs, machineConfigInputsNoComments, argoCDInputs)

	if err = saveInputs(inputs); err != nil {
		return fmt.Errorf("saving inputs: %w", err)
	}

	zstdDict, err := dict.BuildZstdDict(inputs, dict.Options{
		ZstdDictID:  zstdDictID,
		MaxDictSize: zstdMaxDictSize,
		HashBytes:   8,
		ZstdLevel:   zstd.SpeedDefault,
		Output:      logWriter,
	})
	if err != nil {
		return fmt.Errorf("building dictionary: %w", err)
	}

	if err = os.WriteFile(destPath, zstdDict, 0o644); err != nil {
		return fmt.Errorf("writing dictionary to %q: %w", destPath, err)
	}

	logger.Info("wrote zstd dictionary", zap.String("path", destPath), zap.Int("size", len(zstdDict)))

	return nil
}

// saveInputs writes each training sample to its own file, inputs/input-<index>.yaml, relative to
// the current working directory.
//
// Any existing "inputs" directory is removed first, so afterwards it holds exactly the given samples.
// The directory can be used to train a dictionary with the zstd CLI and compare the result with the
// one produced by this tool, e.g.:
//
//	$ zstd --train -r inputs -o config-via-cli.zdict --dictID 1 --maxdict=64KB
func saveInputs(inputs [][]byte) error {
	const outDir = "inputs"

	err := os.RemoveAll(outDir)
	if err == nil {
		err = os.MkdirAll(outDir, 0o755)
	}

	if err != nil {
		return err
	}

	for idx := range inputs {
		target := filepath.Join(outDir, fmt.Sprintf("input-%d.yaml", idx))

		if err = os.WriteFile(target, inputs[idx], 0o644); err != nil {
			return err
		}
	}

	return nil
}

// generateArgoCDManifestInputs renders the Argo CD Helm chart and returns the manifests split into
// line-aligned chunks, for use as dictionary training samples.
//
// Rendering parameters:
//   - chart argoCDChartRepoName/argoCDChartName at version argoCDChartVersion;
//   - values from the embedded argoCDValues;
//   - release argoCDChartReleaseName in namespace argoCDChartNamespace;
//   - target Kubernetes version argoCDChartKubeVersion.
//
// The chart repository is first added to (or updated in) the Helm client's repository
// configuration; this presumably needs network access.
//
// Each returned chunk is at most maxSizeBytes long. The exception is a single line longer than
// maxSizeBytes, which forms an oversized chunk of its own.
func generateArgoCDManifestInputs(maxSizeBytes int) ([][]byte, error) {
	helmClient, err := helmclient.New(&helmclient.Options{})
	if err != nil {
		return nil, fmt.Errorf("creating helm client: %w", err)
	}

	if err = helmClient.AddOrUpdateChartRepo(repo.Entry{
		Name: argoCDChartRepoName,
		URL:  argoCDChartRepo,
	}); err != nil {
		return nil, fmt.Errorf("adding chart repository %q: %w", argoCDChartRepo, err)
	}

	kubeVersion, err := chartutil.ParseKubeVersion(argoCDChartKubeVersion)
	if err != nil {
		return nil, fmt.Errorf("parsing kube version %q: %w", argoCDChartKubeVersion, err)
	}

	argoCDManifests, err := helmClient.TemplateChart(
		&helmclient.ChartSpec{
			ReleaseName: argoCDChartReleaseName,
			ChartName:   argoCDChartRepoName + "/" + argoCDChartName,
			Namespace:   argoCDChartNamespace,
			ValuesYaml:  argoCDValues,
			Version:     argoCDChartVersion,
		}, &helmclient.HelmTemplateOptions{
			KubeVersion: kubeVersion,
			APIVersions: chartutil.DefaultVersionSet,
		})
	if err != nil {
		return nil, fmt.Errorf("rendering chart %s/%s: %w", argoCDChartRepoName, argoCDChartName, err)
	}

	return splitLinesIntoChunks(argoCDManifests, maxSizeBytes)
}

// splitLinesIntoChunks splits data at line boundaries into chunks of at most maxSizeBytes bytes.
//
// Every line is re-terminated with a single "\n" (bufio.ScanLines drops the original "\n" or
// "\r\n"). A line that alone exceeds maxSizeBytes is emitted as its own chunk rather than split.
func splitLinesIntoChunks(data []byte, maxSizeBytes int) ([][]byte, error) {
	scanner := bufio.NewScanner(bytes.NewReader(data))

	// Allow lines as long as the whole input. Under the default 64KB token limit, a single long line
	// would stop the scan with bufio.ErrTooLong and silently drop the rest of the input.
	scanner.Buffer(nil, max(len(data)+1, bufio.MaxScanTokenSize))

	var (
		chunks [][]byte
		chunk  bytes.Buffer
	)

	for scanner.Scan() {
		line := scanner.Bytes()

		// Flush before appending a line (plus its newline) would push the chunk past the limit,
		// so chunks never exceed maxSizeBytes unless a single line does.
		if chunk.Len() > 0 && chunk.Len()+len(line)+1 > maxSizeBytes {
			chunks = append(chunks, slices.Clone(chunk.Bytes()))

			chunk.Reset()
		}

		chunk.Write(line)
		chunk.WriteByte('\n')
	}

	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("splitting manifests into chunks: %w", err)
	}

	if chunk.Len() > 0 {
		chunks = append(chunks, slices.Clone(chunk.Bytes()))
	}

	return chunks, nil
}

func generateMachineConfigInputs(num int, includeComments bool) ([][]byte, error) {
	inputs := make([][]byte, 0, num)

	for range num {
		data, err := randomMachineConfig()
		if err != nil {
			return nil, err
		}

		if !includeComments {
			data, err = removeYAMLComments(data)
			if err != nil {
				return nil, err
			}
		}

		inputs = append(inputs, data)
	}

	return inputs, nil
}

// removeYAMLComments returns data re-encoded without any YAML comments.
//
// The document is decoded into a yaml.Node tree rather than a map. This preserves mapping key
// order, scalar styles and non-mapping top-level documents; only head, line and foot comments
// are cleared. As with yaml.Unmarshal, only the first document in data is processed.
// Empty input (no document) yields empty output.
func removeYAMLComments(data []byte) ([]byte, error) {
	var doc yaml.Node

	if err := yaml.Unmarshal(data, &doc); err != nil {
		return nil, err
	}

	if doc.Kind == 0 {
		// Nothing was decoded (empty input); a zero Node cannot be marshaled.
		return []byte{}, nil
	}

	stripYAMLComments(&doc)

	return yaml.Marshal(&doc)
}

// stripYAMLComments clears all comments on node and, recursively, on its content nodes.
func stripYAMLComments(node *yaml.Node) {
	node.HeadComment = ""
	node.LineComment = ""
	node.FootComment = ""

	// Alias targets are not followed: the anchored node is itself part of the Content tree.
	for _, child := range node.Content {
		stripYAMLComments(child)
	}
}

// randomMachineConfig draws a new set of random template options and renders the embedded
// machine config template with them. It returns the rendered YAML, or the parse/execute error.
func randomMachineConfig() ([]byte, error) {
	opts := randomTemplateOptions()

	tmpl := template.New("machineconfig")
	if _, err := tmpl.Parse(machineConfigTemplate); err != nil {
		return nil, err
	}

	var rendered bytes.Buffer
	if err := tmpl.Execute(&rendered, opts); err != nil {
		return nil, err
	}

	return rendered.Bytes(), nil
}

// templateOptions holds the values substituted into the machine config template.
type templateOptions struct {
	// MachineType is either "controlplane" or "worker".
	MachineType string
	// MachineToken, e.g. 4llcpr.xfg2q2s0lshj2t40.
	MachineToken string
	// MachineCaCrt, e.g. LS0tLS1CRUdJTiBDRVJUSU...
	MachineCaCrt string
	// Domain, e.g. omni.omni-local.utkuozdemir.org.
	Domain string
	// Installer is either "installer" or "installer-secureboot".
	Installer string
	// SchematicID, e.g. 376567988ad370138ad8b2698212367b8edcb69b5fd68c80be1f2ec7d603b4ba.
	SchematicID string
	// ClusterID, e.g. gzlckavg-OUDujIaLx5PFDw17C4WrT9JL-_yjYoo1SY=.
	ClusterID string
	// ClusterSecret, e.g. w5oJWDpnnju9CDA+BQY2jVKjhkVMmqU/oAa3S2Zj+OI=.
	ClusterSecret string
	// ClusterToken, e.g. o1s1dd.mhwu1453qbgoedgy.
	ClusterToken string
	// ClusterCaCrt, e.g. LS0tLS1CRUdJTiBDR...
	ClusterCaCrt string
	// JoinToken, e.g. w7uVuW3zaaaaaaaaacyetAHeYMeo5q2L9RvkAVfCfSCD.
	JoinToken string
}

// randomTemplateOptions returns template options filled with random values.
//
// Machine type, domain and installer come from their dedicated helpers. Every other field is a
// randomString: 32 characters for tokens, 64 for IDs/secrets, 1024 for CA certificates.
func randomTemplateOptions() templateOptions {
	var opts templateOptions

	// Assigned in the same order as the fields are declared, so random values are drawn in that order.
	opts.MachineType = randomMachineType()
	opts.MachineToken = randomString(32)
	opts.MachineCaCrt = randomString(1024)
	opts.Domain = randomDomain()
	opts.Installer = randomInstaller()
	opts.SchematicID = randomString(64)
	opts.ClusterID = randomString(64)
	opts.ClusterSecret = randomString(64)
	opts.ClusterToken = randomString(32)
	opts.ClusterCaCrt = randomString(1024)
	opts.JoinToken = randomString(64)

	return opts
}

func randomString(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.=+/-_"

	b := make([]byte, n)

	for i := range b {
		b[i] = letters[rand.IntN(len(letters))]
	}

	return string(b)
}

// tlds lists the top-level domains randomDomain picks from.
var tlds = []string{"com", "org", "net", "de", "ru"}

// machineTypes lists the machine roles randomMachineType picks from.
var machineTypes = []string{"controlplane", "worker"}

// installers lists the installer names randomInstaller picks from.
var installers = []string{"installer", "installer-secureboot"}

// randomDomain returns a 32-character random label, a dot, and a top-level domain picked
// uniformly from tlds.
func randomDomain() string {
	label := randomString(32)
	tld := tlds[rand.IntN(len(tlds))]

	return label + "." + tld
}

// randomMachineType returns an element of machineTypes chosen uniformly at random.
func randomMachineType() string {
	idx := rand.IntN(len(machineTypes))

	return machineTypes[idx]
}

// randomInstaller returns an element of installers chosen uniformly at random.
func randomInstaller() string {
	idx := rand.IntN(len(installers))

	return installers[idx]
}
